import gradio as gr
from app.app_init import device, model_L14, preprocess_L14, model_H14, preprocess_H14
from app.gradio.constants import description
from app.app_init import substitute_dict
from pymilvus import Collection
from itertools import groupby

from app.utll.embedding_util import get_embedding


def text2image_gr_l2():
    """Build the Gradio Blocks page for text-to-image search ranked by L2 distance.

    Returns:
        The ``gr.Blocks`` object holding the search form and the result gallery.
    """

    def clip_api(query_param='', return_num=8, top_k_num=8, thresh_hold=1.1, user_id="1", model_name="ViT-H-14"):
        """Search the Milvus collection for images matching a text query.

        Args:
            query_param: Search text. Replaced by its ``substitute_dict`` entry
                when one exists.
            return_num: Milvus ``limit`` used for each query embedding.
            top_k_num: Maximum number of image paths returned.
            thresh_hold: Maximum L2 distance; hits farther away are dropped.
            user_id: Only rows whose ``user_id`` equals this value are searched.
            model_name: ``"ViT-L-14"`` or ``"ViT-H-14"``.

        Returns:
            A list of ``image_path`` values, nearest first, one per distinct id.

        Raises:
            ValueError: If ``model_name`` is not one of the supported models.
        """
        if model_name == "ViT-L-14":
            collection_name = "multimodal_search_L14_L2"
            model = model_L14
            preprocess = preprocess_L14
        elif model_name == "ViT-H-14":
            collection_name = "multimodal_search_H14_L2"
            model = model_H14
            preprocess = preprocess_H14
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(f"Unsupported model_name: {model_name!r}")

        print("当前使用的模型为：", model_name)

        milvus_conn = Collection(collection_name)
        model_infos = {"model": model, "preprocess": preprocess, "device": device}

        print("原始搜索词为：", query_param)

        # Look up a replacement for the search text; when one exists, search
        # with the replacement instead of the original text.
        substitute_words = substitute_dict.get(query_param)
        if substitute_words:
            query_param = substitute_words

        print("本次搜索最终词语为:", query_param)

        query_embs = get_embedding("text", query_param, model_infos, batch_mode=True)

        # Closest hit per id, accumulated across ALL query embeddings so that
        # earlier embeddings are not discarded by later ones.
        best_by_id = {}
        for query_emb in query_embs:
            query_item = {
                "data": [query_emb],
                "anns_field": "vector",
                "param": {"metric_type": "L2", "params": {"nprobe": 10}, "offset": 0},
                # The value comes from a slider; Milvus needs an integer limit.
                "limit": int(return_num),
                # NOTE(review): user_id is interpolated into the filter
                # expression unescaped; a value containing a quote would alter
                # the expression. Validate it upstream if it can be untrusted.
                "expr": f"(user_id=='{user_id}')",
                "output_fields": ["id", "user_id", "image_path"]
            }

            print(f"query_item: {query_item}")
            search_result = milvus_conn.search(**query_item)

            for hit in search_result[0]:
                row = hit.entity.to_dict()
                distance = row["distance"]
                # L2 is a distance: smaller means more similar, so drop hits
                # that are farther away than the threshold.
                if distance > thresh_hold:
                    continue
                entity = row["entity"]
                item = {
                    "id": entity["id"],
                    "user_id": entity["user_id"],
                    "image_path": entity["image_path"],
                    "score": distance,
                }
                kept = best_by_id.get(item["id"])
                if kept is None or distance < kept["score"]:
                    best_by_id[item["id"]] = item

        # Nearest first, then keep at most top_k_num entries.
        filtered_result = sorted(best_by_id.values(), key=lambda row: row["score"])[:int(top_k_num)]

        print("filtered_result:", filtered_result)

        # Collect the image path of every remaining hit.
        image_paths_list = [item['image_path'] for item in filtered_result]

        return image_paths_list

    examples = [
        ["广州塔", 20, 16, 1.53, "1", "ViT-H-14"],
        ["欧式", 20, 16, 1.53, "1", "ViT-H-14"],
        ["夜景", 20, 16, 1.53, "1", "ViT-H-14"],
        ["桥梁", 20, 16, 1.53, "1", "ViT-H-14"],
    ]

    title = "<h1 align='center'>基于AI大模型和向量数据库的图片搜索引擎</h1>"

    with gr.Blocks() as image_block:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            # Left column: search form.
            with gr.Column(scale=1):
                with gr.Column(scale=2):
                    query_param_text = gr.Textbox(value="花草树木", label="填写文本", elem_id=0, interactive=True)
                num = gr.components.Slider(minimum=0, maximum=200, step=1, value=50, label="返回图片数",
                                           elem_id=1)
                top_k = gr.components.Slider(minimum=0, maximum=200, step=1, value=20,
                                             label="top_k：召回的图片里面，相似分大于设定阈值的，前K张图片",
                                             elem_id=2)
                distance_thresh_hold = gr.components.Slider(minimum=0, maximum=5, step=0.01, value=1.53,
                                                            label="距离",
                                                            elem_id=3)
                user_id_text = gr.Textbox(value="1", label="填写user_id", elem_id=4, interactive=True)
                # Same component API (value=...) as the other widgets, instead
                # of the deprecated gr.inputs namespace.
                model_name_radio = gr.Radio(["ViT-H-14"], value="ViT-H-14", label="模型选择")
                btn = gr.Button("搜索", )
            # Right column: result gallery.
            with gr.Column(scale=100):
                out = gr.Gallery(label="检索结果为：").style(grid=4, height=600)
        # Order must match the positional parameters of clip_api.
        inputs = [query_param_text, num, top_k, distance_thresh_hold, user_id_text, model_name_radio]
        btn.click(fn=clip_api, inputs=inputs, outputs=out)
        gr.Examples(examples, inputs=inputs)
    return image_block


if __name__ == "__main__":
    # Serve the L2 search page as the single tab of a tabbed interface.
    tab_pages = [text2image_gr_l2()]
    tab_titles = ["文到图搜索L2(distance)"]
    with gr.TabbedInterface(tab_pages, tab_titles) as app:
        app.launch(enable_queue=True)
